Causal Inference
Simulate Data
- Assume a severity variable increases the treatment selection by an odds ratio of 1.5
- The baseline probability of getting treatment is 0.2
- The severity variable also increases the outcome risk by an odds ratio of 2
- The treatment reduces the outcome risk by an odds ratio of 0.5
- The baseline probability of outcome (death) is 0.2
# Sample size and four independent standard-normal covariates.
n <- 500
dt <- data.table(
  severity = rnorm(n),
  covar1 = rnorm(n),
  covar2 = rnorm(n),
  covar3 = rnorm(n)
)

# Treatment-selection parameters: baseline probability of treatment and the
# odds ratio per unit of severity.
p0 <- 0.2
or_trt <- 1.5

# Build the treatment propensity step by step: baseline odds, then the
# log-odds shifted by severity, then back-transformed to a probability.
dt[, odds_trt_0 := p0 / (1 - p0)]
dt[, log_odds_trt := log(odds_trt_0) + severity * log(or_trt)]
dt[, p_trt := exp(log_odds_trt) / (1 + exp(log_odds_trt))]
#' Draw one Bernoulli (0/1) value per success probability
#'
#' @param p Numeric vector of success probabilities, each in [0, 1].
#' @return Numeric vector of 0/1 draws with the same length as `p`
#'   (`numeric(0)` for zero-length input).
vsample <- function(p) {
  # rbinom() is already vectorised over `prob`, so no per-element sample()
  # call wrapped in Vectorize() is needed. as.numeric() keeps the double
  # 0/1 values the sample(c(1, 0), ...) version produced.
  as.numeric(rbinom(length(p), size = 1L, prob = p))
}
# Assign treatment with one 0/1 draw per subject from its propensity p_trt.
dt <- dt[, trt := vsample(p_trt)]

# Outcome-model parameters, following the simulation design stated at the top:
# baseline outcome probability 0.2, severity raises the odds of the outcome
# (OR 2) and treatment lowers them (OR 0.5).
p <- 0.2
or_severity_outcome <- 2
or_trt_outcome <- 0.5

# Fix: the linear predictor used log(0.5) for severity and had no trt term,
# so the simulated treatment had no effect on the outcome. covar1 and covar2
# are kept as additional outcome predictors with their original odds ratios.
# NOTE(review): the tables rendered further down were produced with the old
# predictor -- re-knit the report to refresh them.
dt <- dt[, odds_outcome_0 := p / (1 - p)
][, log_odds_outcome := log(odds_outcome_0) +
    severity * log(or_severity_outcome) +
    trt * log(or_trt_outcome) +
    covar1 * log(2.5) +
    covar2 * log(1.5)
][, p_outcome := exp(log_odds_outcome) / (1 + exp(log_odds_outcome))]
dt <- dt[, outcome := vsample(p_outcome)]
Treatment Model
# Treatment (propensity) model: logistic regression of the treatment
# indicator on severity and the three covariates.
m_trt <- glm(
  trt ~ severity + covar1 + covar2 + covar3,
  family = binomial(link = "logit"),
  data = dt
)
library(sjPlot)
tab_model(m_trt)
| trt | |||
|---|---|---|---|
| Predictors | Odds Ratios | CI | p |
| (Intercept) | 0.23 | 0.19 – 0.29 | <0.001 |
| severity | 1.41 | 1.13 – 1.78 | 0.003 |
| covar1 | 1.10 | 0.88 – 1.38 | 0.388 |
| covar2 | 1.06 | 0.85 – 1.32 | 0.587 |
| covar3 | 0.98 | 0.78 – 1.23 | 0.836 |
| Observations | 500 | ||
| R2 Tjur | 0.020 | ||
Distribution of Propensity Score
# Propensity score: fitted treatment probability from the treatment model.
dt$pred <- m_trt$fitted.values

# Scores split by treatment arm. Named pred_* rather than p1/p0 so the
# baseline treatment probability `p0` used by the simulation is not
# overwritten with a vector.
pred_treated <- dt[trt == 1]$pred
pred_untreated <- dt[trt == 0]$pred

# Overlaid histograms of the propensity score by arm; both traces inherit
# the alpha and bin settings given to plot_ly().
plot_ly(alpha = 0.5, xbins = list(start = 0, end = 1, size = 0.02)) %>%
  add_histogram(x = ~pred_treated, name = "Treated", inherit = TRUE) %>%
  add_histogram(x = ~pred_untreated, name = "Not Treated", inherit = TRUE) %>%
  layout(
    barmode = "overlay",
    xaxis = list(
      title = "Predicted Probabilities of Being Treated",
      zeroline = FALSE
    ),
    yaxis = list(title = "Count", zeroline = FALSE)
  )

# Same comparison as density curves. `labels` is spelled out in full instead
# of relying on partial matching of `label`.
ggplot(data = dt, aes(x = pred, colour = factor(trt))) +
  geom_density(alpha = 0.5) +
  labs(y = "Density", x = "pred") +
  scale_y_continuous(breaks = NULL, labels = c("")) +
  theme(legend.position = c(0.8, 0.2))
Multivariable Outcome Model
# Multivariable outcome model: logistic regression of the outcome on
# severity, treatment and the three covariates.
m_outcome_mv <- glm(
  outcome ~ severity + trt + covar1 + covar2 + covar3,
  family = binomial(link = "logit"),
  data = dt
)
tab_model(m_outcome_mv)
| outcome | |||
|---|---|---|---|
| Predictors | Odds Ratios | CI | p |
| (Intercept) | 0.27 | 0.20 – 0.35 | <0.001 |
| severity | 0.52 | 0.40 – 0.67 | <0.001 |
| trt | 0.70 | 0.36 – 1.30 | 0.270 |
| covar1 | 2.67 | 2.07 – 3.49 | <0.001 |
| covar2 | 1.48 | 1.17 – 1.88 | 0.001 |
| covar3 | 1.00 | 0.80 – 1.26 | 0.989 |
| Observations | 500 | ||
| R2 Tjur | 0.214 | ||
Causal Forest
- grf runs regression forest
library(grf)
# Build the inputs for the causal forest from a one-sided formula
# (empty outcome) over the predictors.
frml <- Wu::wu_formula(
  outcome = "",
  predictors = c("covar1", "covar2", "covar3", "trt", "severity")
)
mmx <- model.matrix.lm(frml, data = dt)

# X: covariates only -- remove the treatment column, then the first remaining
# column (presumably the intercept; confirm against Wu::wu_formula).
X <- mmx[, colnames(mmx) != "trt", drop = FALSE][, -1, drop = FALSE]

# Y: observed outcome; W: treatment indicator taken from the model matrix.
Y <- dt[["outcome"]]
W <- mmx[, "trt"]
# Fit the causal forest on covariates X, outcome Y and treatment W.
# Y.hat and W.hat are left NULL (not supplied by the caller), tuning is off,
# and the seed is fixed for the forest itself.
# Options left disabled:
##   clusters = dt[["department"]]
##   tune.parameters = c("sample.fraction", "mtry")
##   tune.num.trees = 4000
rf <- causal_forest(
  X, Y, W,
  Y.hat = NULL,
  W.hat = NULL,
  num.trees = 1000,
  honesty = TRUE,
  tune.parameters = "none",
  compute.oob.predictions = TRUE,
  seed = 123456
)
# AIPW average treatment effect over the whole sample and over the treated.
ate <- average_treatment_effect(rf, method = "AIPW", target.sample = "all")
ate_treated <- average_treatment_effect(
  rf,
  method = "AIPW",
  target.sample = "treated"
)

# Build the summary table with numeric columns from the start. Binding the
# label and the numbers into one character matrix and converting back with
# as.numeric() pushed the estimates through a lossy text round trip.
# Elements are taken by position (1 = estimate, 2 = standard error), the same
# order the column names were assigned in before.
rst <- data.table(
  name = c("Average Effect", "Average Effect on Treated"),
  estimate = c(ate[[1]], ate_treated[[1]]),
  se = c(ate[[2]], ate_treated[[2]])
)

# 95% normal-approximation confidence limits.
rst <- rst[, lower := estimate - qnorm(0.975) * se
][, upper := estimate + qnorm(0.975) * se]
rst %>% prt(digits=3)
| name | estimate | se | lower | upper |
|---|---|---|---|---|
| Average Effect | -0.047 | 0.044 | -0.132 | 0.039 |
| Average Effect on Treated | -0.050 | 0.042 | -0.132 | 0.032 |
R sessionInfo
R version 4.2.0 (2022-04-22) Platform: x86_64-pc-linux-gnu (64-bit) Running under: Ubuntu 20.04.3 LTS
Matrix products: default BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0 LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
locale: [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
[4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
[7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C
attached base packages: [1] stats graphics grDevices utils datasets methods base
other attached packages: [1] grf_2.1.0 sjPlot_2.8.10 Wu_0.0.0.9000
[4] flexdashboard_0.5.2 lme4_1.1-29 Matrix_1.4-0
[7] mgcv_1.8-38 nlme_3.1-152 png_0.1-7
[10] scales_1.2.0 nnet_7.3-16 labelled_2.9.1
[13] kableExtra_1.3.4 plotly_4.10.0 gridExtra_2.3
[16] ggplot2_3.3.6 DT_0.23 tableone_0.13.2
[19] magrittr_2.0.3 lubridate_1.8.0 dplyr_1.0.9
[22] plyr_1.8.7 data.table_1.14.2 rmdformats_1.0.4
[25] knitr_1.39
loaded via a namespace (and not attached): [1] TH.data_1.1-1 minqa_1.2.4 colorspace_2.0-3 ellipsis_0.3.2
[5] sjlabelled_1.2.0 estimability_1.4 parameters_0.18.1 rstudioapi_0.13
[9] farver_2.1.1 fansi_1.0.3 mvtnorm_1.1-3 xml2_1.3.3
[13] codetools_0.2-18 splines_4.2.0 sjmisc_2.8.9 jsonlite_1.8.0
[17] nloptr_2.0.3 ggeffects_1.1.2 broom_0.8.0 effectsize_0.7.0
[21] compiler_4.2.0 httr_1.4.3 sjstats_0.18.1 emmeans_1.7.5
[25] backports_1.4.1 assertthat_0.2.1 fastmap_1.1.0 lazyeval_0.2.2
[29] survey_4.1-1 cli_3.3.0 htmltools_0.5.3 tools_4.2.0
[33] coda_0.19-4 gtable_0.3.0 glue_1.6.2 Rcpp_1.0.9
[37] jquerylib_0.1.4 vctrs_0.4.1 svglite_2.1.0 crosstalk_1.2.0
[41] insight_0.18.0 xfun_0.31 stringr_1.4.0 rvest_1.0.2
[45] lifecycle_1.0.1 klippy_0.0.0.9500 MASS_7.3-54 zoo_1.8-10
[49] hms_1.1.1 sandwich_3.0-2 yaml_2.3.5 sass_0.4.1
[53] stringi_1.7.8 highr_0.9 bayestestR_0.12.1 boot_1.3-28
[57] rlang_1.0.4 pkgconfig_2.0.3 systemfonts_1.0.4 evaluate_0.15
[61] lattice_0.20-45 purrr_0.3.4 htmlwidgets_1.5.4 labeling_0.4.2
[65] tidyselect_1.1.2 bookdown_0.27 R6_2.5.1 generics_0.1.3
[69] multcomp_1.4-19 DBI_1.1.2 pillar_1.8.0 haven_2.5.0
[73] withr_2.5.0 survival_3.2-13 datawizard_0.4.1 tibble_3.1.8
[77] performance_0.9.1 modelr_0.1.8 utf8_1.2.2 rmarkdown_2.14
[81] grid_4.2.0 forcats_0.5.1 digest_0.6.29 webshot_0.5.3
[85] xtable_1.8-4 tidyr_1.2.0 munsell_0.5.0 viridisLite_0.4.0
[89] bslib_0.3.1 mitools_2.4